
include("../src/qmcmary.jl")
using ..qmcmary

using ReverseDiff: gradient, jacobian

"""
    aexp(x; dt=0.1)

Return `acosh.(exp.(x .* dt ./ 2))`, i.e. elementwise the `λ ≥ 0` that solves
`cosh(λ) = exp(x * dt / 2)`.

`dt` defaults to `0.1`, the value that used to be hard-coded here, so existing
calls `aexp(x)` behave exactly as before. `x` may be a scalar or an array.

Throws a `DomainError` for any entry with `x * dt < 0`, because `acosh` needs an
argument `≥ 1`.
"""
function aexp(x; dt=0.1)
    lams = acosh.(exp.(x .* dt ./ 2))
    return lams
end

"""
计算B矩阵的导数
"""
function grad_bmat()
    Uz = [1.0, 1.0]
    #bmat = bmat_Ising(rand(2), U=Uv)
    cfg = [1, -1]
    ehk = lattice_chain(2, -1.0)
    ehk = kron([1 0; 0 1], ehk)
    println(bmat_IsingAD(ehk, cfg, [1.0 1.0], zeros(2), zeros(2), ones(2)))
    grad = gradient(Uz -> bmat_IsingAD(ehk, cfg, [1.0 1.0], zeros(2), zeros(2), ones(2))[1][1, 1], Uz)
    #grad = gradient(xs -> aexp(xs)[1], (Uz))
    println(grad)
end


"""theta到"""
function grad_theta(thes::Float64, phis::Float64)
    fnx(t, p) = cos.(t).*sin.(p)
    fny(t, p) = sin.(t).*sin.(p)
    fnz(t, p) = cos.(p)
    (pnxpt, pnxpp) = gradient(fnx, ([thes], [phis]))
    (pnypt, pnypp) = gradient(fny, ([thes], [phis]))
    (pnzpt, pnzpp) = gradient(fnz, ([thes], [phis]))
    return pnxpt[1], pnxpp[1], pnypt[1], pnypp[1], pnzpt[1], pnzpp[1]
end


"""
    chain_test()

Check the AD Jacobian of the B matrix against a hand-assembled chain rule.

Each site's unit vector is parameterised by angles `(θ, φ)` as
`n = (cos θ sin φ, sin θ sin φ, cos φ)`. Two routes to
`∂ Im(B)[1, 3] / ∂φ_k` are computed and printed:

1. the Jacobian of the whole second component of `bmat_IsingAD` (the imaginary
   part, per the comparison in `value_test`) with respect to `φ`, via
   `ReverseDiff.jacobian`;
2. the chain rule `Σ_c ∂f/∂n^c_k · ∂n^c_k/∂φ_k`. It combines the gradient of
   `f = Im(B)[1, 3]` with respect to `(nx, ny, nz)` and the per-site partials
   from `grad_theta`.

`n_k` depends only on `(θ_k, φ_k)`, so no cross-site terms appear in route 2.
The final line prints whether the two routes agree.
"""
function chain_test()
    thes = [0.25*pi, 0.0]
    phis = [0.17*pi, 0.2]
    cfg = [1, -1]
    nx = cos.(thes).*sin.(phis)
    ny = sin.(thes).*sin.(phis)
    nz = cos.(phis)
    # Single assignment so the closures below do not capture a reassigned (boxed) variable.
    ehk = kron([1 0; 0 1], lattice_chain(2, -1.0))
    eV = bmat_IsingAD(ehk, cfg, [1.0, 1.0], nx, ny, nz)
    println(eV)

    # Route 1: Jacobian of the imaginary-part matrix w.r.t. (θ, φ).
    # Each Jacobian has one row per matrix entry (column-major order) and one column per site.
    fetp(t, p) = bmat_IsingAD(ehk, cfg, [1.0, 1.0], sin.(p).*cos.(t), sin.(p).*sin.(t), cos.(p))[2]
    _, jpepp = jacobian(fetp, (thes, phis))
    println(jpepp)
    println(size(jpepp))
    # Fold rows back into matrix shape: jpepp[i, j, k] = ∂Im(B)[i, j]/∂φ_k.
    jpepp = reshape(jpepp, size(eV[2])..., length(phis))
    for k in axes(jpepp, 3)
        println(jpepp[:, :, k])
    end

    # Route 2: chain rule through the Cartesian components of n.
    fenx(x, y, z) = bmat_IsingAD(ehk, cfg, [1.0, 1.0], x, y, z)[2][1, 3]
    # pexpnx[k] = ∂f/∂nx_k, and likewise for ny, nz.
    (pexpnx, pexpny, pexpnz) = gradient(fenx, (nx, ny, nz))
    # chain[k] = ∂f/∂φ_k; only the φ-partials of n_k are needed here.
    chain = map(eachindex(thes)) do k
        _, pnxpp, _, pnypp, _, pnzpp = grad_theta(thes[k], phis[k])
        pexpnz[k]*pnzpp + pexpny[k]*pnypp + pexpnx[k]*pnxpp
    end
    println(chain)

    # Both routes must give the same ∂Im(B)[1, 3]/∂φ.
    println(isapprox(jpepp[1, 3, :], chain, atol=1e-12))
    return nothing
end


"""
    value_test()

Cross-check `bmat_IsingAD` against `bmat_IsingND`, then validate one AD
derivative by finite differences.

- `bmat_IsingAD` returns the B matrix as a (real part, imaginary part) pair.
  `bmat_IsingND` returns it as a single complex matrix; the comparison below
  assumes this. The two are compared entry by entry.
- The AD derivative `∂ Re(B)[1, 1] / ∂φ_1` is printed next to the forward
  difference `Re(B(φ_1 + δ) - B(φ_1))[1, 1] / δ`, computed with `bmat_IsingND`.

Site unit vectors use `n = (cos θ sin φ, sin θ sin φ, cos φ)` throughout, both
for the direct evaluations and inside the differentiated closure.
"""
function value_test()
    thes = [0.0, 0.0]
    phis = [0.190, 0.39]
    cfg = [1, -1]
    # Single assignment so the closure below does not capture a reassigned (boxed) variable.
    ehk = kron([1 0; 0 1], lattice_chain(2, -1.0))
    unitvec(t, p) = (cos.(t).*sin.(p), sin.(t).*sin.(p), cos.(p))

    nx, ny, nz = unitvec(thes, phis)
    eV = bmat_IsingAD(ehk, cfg, [1.0, 1.0], nx, ny, nz)
    eV2 = bmat_IsingND(ehk, cfg, [1.0, 1.0], nx, ny, nz)
    println(eV)
    println(eV2)
    println(isapprox.(eV[1]+eV[2]*im, eV2, atol=1e-12))

    # AD derivative of Re(B)[1, 1] w.r.t. (θ, φ). Use the full unit-vector
    # parameterisation; `sin.(p), sin.(t)` only coincides with it at θ = 0.
    fetp(t, p) = bmat_IsingAD(ehk, cfg, [1.0, 1.0], unitvec(t, p)...)[1][1, 1]
    _, pbpp = gradient(fetp, (thes, phis))
    println(pbpp[1])

    # Forward difference in φ_1 only, at the same perturbed point as before (0.190 → 0.191).
    phis3 = [0.191, 0.39]
    δ = phis3[1] - phis[1]
    eV3 = bmat_IsingND(ehk, cfg, [1.0, 1.0], unitvec(thes, phis3)...)
    println(real(eV3[1, 1] - eV2[1, 1]) / δ)
    return nothing
end

# Run the checks in order: AD w.r.t. U, chain-rule check in φ, then AD-vs-ND values.
for check in (grad_bmat, chain_test, value_test)
    check()
end
